#!/usr/bin/env python3
"""
This script generates a markdown table (README.md) from a directory that contains all the plot images.

NOTE: The image links are relative (``./<file>.png``), so the markdown renders correctly only when it is placed in the same directory as the images.
"""

from pathlib import Path
import pandas as pd
def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``plotdir``, ``envs`` (list of str),
        ``num_expsets`` (int) and ``out`` (str or None).
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("plotdir", help="The directory that contains plot files")
    # nargs="+" so that `--envs a b` yields a list; without it a single
    # value would be a plain string and get iterated character by character.
    parser.add_argument(
        "--envs",
        nargs="+",
        default=['cartpole-swingup', 'cheetah-run', 'pendulum-swingup', 'walker-walk'],
        help="list of environments",
    )
    parser.add_argument("--num-expsets", default=5, type=int)
    parser.add_argument("--out", default=None, help="The location of output README.md. Default to plotdir")
    return parser.parse_args(argv)


def collect_plot_grid(plotdir, envs, num_expsets):
    """Locate the main plot for every (environment, experiment set) pair.

    Plot files are expected to be named ``main-plot-{env}-{j}.png`` with
    ``j`` in ``range(num_expsets)``, e.g. ``main-plot-cheetah-run-0.png``.
    Not every combination has to exist.

    Args:
        plotdir: Directory that contains the plot files.
        envs: Environment names, one table row each.
        num_expsets: Number of experiment sets, one table column each.

    Returns:
        Nested list where ``grid[i][j]`` is the ``Path`` of the plot for
        ``envs[i]`` / experiment set ``j``, or ``None`` if that file is missing.
    """
    plotdir = Path(plotdir)
    grid = []
    for env in envs:
        row = []
        for j in range(num_expsets):
            fpath = plotdir / f'main-plot-{env}-{j}.png'
            row.append(fpath if fpath.is_file() else None)
        grid.append(row)
    return grid


def build_rows(envs, grid):
    """Convert a plot grid into one table row (a dict) per environment.

    Each row maps ``'environment'`` to the environment name and every
    experiment-set index ``j`` to a relative markdown image link. A missing
    plot becomes ``''``, so every row has the same columns in the same order
    and the rendered table shows an empty cell instead of ``nan``.

    Args:
        envs: Environment names, in row order.
        grid: Output of ``collect_plot_grid`` for the same ``envs``.

    Returns:
        List of row dicts, one per environment.
    """
    rows = []
    for env, plots in zip(envs, grid):
        row = {'environment': env}
        for j, fpath in enumerate(plots):
            row[j] = f"![plot](./{fpath.name})" if fpath is not None else ''
        rows.append(row)
    return rows


def main(argv=None):
    """Write a markdown table (num_envs x num_expsets) of plot images.

    The output goes to ``--out``, defaulting to ``<plotdir>/README.md``.
    """
    args = parse_args(argv)
    grid = collect_plot_grid(args.plotdir, args.envs, args.num_expsets)

    # Build the DataFrame once from all rows instead of concatenating per row.
    df = pd.DataFrame(build_rows(args.envs, grid))
    # NOTE: DataFrame.to_markdown requires the optional `tabulate` package.
    mdtxt = df.to_markdown(index=False)

    out = Path(args.plotdir) / 'README.md' if args.out is None else Path(args.out)
    out.write_text(mdtxt + '\n', encoding='utf-8')


if __name__ == '__main__':
    main()
